import json
import time

import numpy as np
import pyrealsense2 as rs

# Product ids (as reported through rs.camera_info.product_id) of the devices
# that this module accepts as supporting advanced mode.
DS5_product_ids = [
    "0AD1", "0AD2", "0AD3", "0AD4", "0AD5",
    "0AF6", "0AFE", "0AFF",
    "0B00", "0B01", "0B03", "0B07", "0B3A",
]

def find_device_that_supports_advanced_mode():
    """ Return the first connected device whose product id is in DS5_product_ids

    The device name is printed when the device reports one.

    :return: the first matching device from the context's device list
    :raises Exception: if no connected device has a listed product id
    """
    # Keep the context referenced for the whole search.
    ctx = rs.context()
    for dev in ctx.query_devices():
        # Devices that do not expose a product id cannot be matched.
        if not dev.supports(rs.camera_info.product_id):
            continue
        if str(dev.get_info(rs.camera_info.product_id)) not in DS5_product_ids:
            continue
        if dev.supports(rs.camera_info.name):
            print("Found device that supports advanced mode:", dev.get_info(rs.camera_info.name))
        return dev
    raise Exception("No device that supports advanced mode was found")


def set_dev_preset(json_path):
    """ Set camera parameters with a json preset file

    Looks up a device that supports advanced mode, enables advanced mode on it
    if it is not already enabled, then loads the preset into the device.
    This is best-effort: any exception is printed and not re-raised.

    :param json_path: json path to the configuration file
    :return: None
    """
    try:
        dev = find_device_that_supports_advanced_mode()
        advnc_mode = rs.rs400_advanced_mode(dev)
        print("Advanced mode is", "enabled" if advnc_mode.is_enabled() else "disabled")

        # Loop until we successfully enable advanced mode
        while not advnc_mode.is_enabled():
            print("Trying to enable advanced mode...")
            advnc_mode.toggle_advanced_mode(True)
            # At this point the device will disconnect and re-connect.
            print("Sleeping for 5 seconds...")
            time.sleep(5)
            # The 'dev' object will become invalid and we need to initialize it again
            dev = find_device_that_supports_advanced_mode()
            advnc_mode = rs.rs400_advanced_mode(dev)
            print("Advanced mode is", "enabled" if advnc_mode.is_enabled() else "disabled")

        with open(json_path, "r") as file:
            as_json_object = json.load(file)

        # Re-serialize with json.dumps so the device always receives valid,
        # double-quoted JSON. Converting the dict with str() and swapping the
        # quote characters breaks on booleans, nulls and on strings that
        # themselves contain a quote or apostrophe.
        json_string = json.dumps(as_json_object)
        print(json_string)
        advnc_mode.load_json(json_string)

    except Exception as e:
        # Deliberately best-effort: report the problem and carry on.
        print(e)


def get_images_from_pipeline(pipeline, align, timeout=5000):
    ''' Retrieve a pair of aligned color and depth images from pipeline

    :param pipeline: pipeline object from retrieve_aligned_pipeline
    :param align: align object from retrieve_aligned_pipeline
    :param timeout: time in ms passed to pipeline.wait_for_frames
    :return: color_image - h x w x 3 uint8 np array
             depth_image - h x w np array in the dtype of the depth stream
                           (presumably uint16 for the z16 stream that
                           retrieve_aligned_pipeline enables)
    :raises RuntimeError: if the aligned frameset lacks a depth or a color
             frame; pyrealsense2 is also expected to raise RuntimeError from
             wait_for_frames when no frames arrive within the timeout
    '''
    frames = pipeline.wait_for_frames(timeout)
    aligned_frames = align.process(frames)

    depth_frame = aligned_frames.get_depth_frame()
    color_frame = aligned_frames.get_color_frame()

    # An invalid (empty) frame is falsy; fail with a clear message instead of
    # an opaque error from get_data() further down.
    if not depth_frame or not color_frame:
        raise RuntimeError("Aligned frameset is missing a depth or color frame")

    depth_image = np.asanyarray(depth_frame.get_data())
    color_image = np.asanyarray(color_frame.get_data()).astype('uint8')

    return color_image, depth_image


def retrieve_aligned_pipeline(filename=None, record_bag=False, width=640, height=480, fps=30, verbose=False):
    ''' Returns pipeline and align objects to be used with get_images_from_pipeline to obtain images

    A z16 depth stream and an rgb8 color stream are enabled with the same
    width, height and fps, the pipeline is started, and an align object
    targeting the color stream is created.

    :param filename: path of a .bag file, or None for a live camera without
                     recording. With record_bag=False the file is played back;
                     with record_bag=True the capture is recorded to it
    :param record_bag: if true and filename is given, record to filename
                       instead of reading from it
    :param width: width of the frame - if a bag file is loaded this must match it
    :param height: height of the frame - if a bag file is loaded this must match it
    :param fps: frame rate requested for both streams
    :param verbose: if true this prints out checkpoints in capture and the
                    intrinsics of both streams
    :return: pipeline, align
    '''
    has_file = filename is not None
    play_from_file = has_file and not record_bag
    record_to_file = has_file and record_bag

    cfg = rs.config()
    if play_from_file:
        cfg.enable_device_from_file(filename, repeat_playback=False)
    cfg.enable_stream(rs.stream.depth, width, height, rs.format.z16, fps)
    cfg.enable_stream(rs.stream.color, width, height, rs.format.rgb8, fps)
    if record_to_file:
        cfg.enable_record_to_file(filename)

    if verbose:
        print("Initializing pipeline!")
    pipeline = rs.pipeline()
    if verbose:
        print("Starting pipeline!")
    p_cfg = pipeline.start(cfg)

    if verbose:
        print("Pipeline started!")
        # Report the intrinsics of each stream under its own label.
        for label, stream in (("Color", rs.stream.color), ("Depth", rs.stream.depth)):
            profile = p_cfg.get_stream(stream)
            intr = profile.as_video_stream_profile().get_intrinsics()
            print("{} stream intrinsics: {}".format(label, intr))

    # This is necessary to get all frames from the bag file
    if play_from_file:
        playback = p_cfg.get_device().as_playback()
        playback.set_real_time(False)

    align_to = rs.stream.color
    align = rs.align(align_to)

    return pipeline, align